/*
 * Copyright (c) 2011-2022, baomidou (jobob@qq.com).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.baomidou.mybatisplus.core.injector.methods;

import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.core.enums.SqlMethod;
import com.baomidou.mybatisplus.core.injector.AbstractMethod;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.core.toolkit.sql.SqlScriptUtils;
import org.apache.ibatis.executor.keygen.Jdbc3KeyGenerator;
import org.apache.ibatis.executor.keygen.KeyGenerator;
import org.apache.ibatis.executor.keygen.NoKeyGenerator;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Batch insert: injects a mapped statement ({@link SqlMethod#INSERT_BATCH}) that inserts every
 * element of a collection parameter with a single SQL statement.
 * <p>
 * The SQL shape differs per database, so each supported {@link DbType} is mapped to a
 * {@link GenSql} fragment generator in {@link #init()}. For a database type without a
 * generator a warning is logged once and the statement is registered without the generated
 * column/value fragments.
 *
 * @author wanglei
 * @since 2022-03-23
 */
public class InsertBatch extends AbstractMethod {

    /**
     * Database type -> generator of that database's batch-insert SQL fragments.
     * Populated by {@link #init()}.
     */
    private static final Map<DbType, GenSql> DB_FUNCTIONS = new HashMap<>();

    /**
     * Database types for which the "does not support batch insert" warning has already been
     * logged, so the warning is emitted only once per type.
     */
    public static final Set<DbType> WARN_SET = new HashSet<>();

    public InsertBatch() {
        super(SqlMethod.INSERT_BATCH.getMethod());
    }

    /**
     * @param name mapper method name
     * @since 3.5.0
     */
    public InsertBatch(String name) {
        super(name);
    }

    /**
     * Builds the batch-insert statement for the table's database type.
     * <p>
     * The generator registered for the database type supplies the column list, the
     * {@code VALUES} keyword (or its replacement) and the per-row value script, which are
     * substituted into {@link SqlMethod#INSERT_BATCH}'s SQL template. If the table uses an
     * auto-increment ({@link IdType#AUTO}) primary key, generated keys are read back through
     * {@link Jdbc3KeyGenerator}.
     *
     * @param mapperClass mapper class
     * @param modelClass  entity class
     * @param tableInfo   table metadata
     * @return the registered mapped statement
     */
    @Override
    public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        SqlMethod sqlMethod = SqlMethod.INSERT_BATCH;
        DbType dbType = tableInfo.getDbType();
        GenSql generator = dbType == null ? null : DB_FUNCTIONS.get(dbType);

        String columnScript = "";
        String valuesKeyword = "";
        String valuesScript = "";
        if (generator != null) {
            Map<SqlLocation, String> sqlMap = generator.genSql(mapperClass, modelClass, tableInfo);
            // Missing fragments default to "" so String.format never writes the text "null" into the SQL
            columnScript = sqlMap.getOrDefault(SqlLocation.COLUMN, "");
            valuesKeyword = sqlMap.getOrDefault(SqlLocation.VALUES, "");
            valuesScript = sqlMap.getOrDefault(SqlLocation.VALUE, "");
        } else if (WARN_SET.add(dbType)) {
            // Unsupported (or unknown) database type: warn only the first time it is seen.
            // dbType may be null here, so it must not be dereferenced unconditionally.
            String dbName = dbType == null ? null : dbType.getDb();
            logger.warn("db type: " + dbName + " does not support batch insert");
        }

        KeyGenerator keyGenerator = NoKeyGenerator.INSTANCE;
        String keyProperty = null;
        String keyColumn = null;
        // Only an auto-increment primary key needs generated keys written back;
        // without a primary key (or with another id type) the key is an ordinary column.
        if (StringUtils.isNotBlank(tableInfo.getKeyProperty()) && tableInfo.getIdType() == IdType.AUTO) {
            keyGenerator = Jdbc3KeyGenerator.INSTANCE;
            keyProperty = tableInfo.getKeyProperty();
            keyColumn = tableInfo.getKeyColumn();
        }
        String sql = String.format(sqlMethod.getSql(), tableInfo.getTableName(), columnScript, valuesKeyword, valuesScript);
        SqlSource sqlSource = languageDriver.createSqlSource(configuration, sql, modelClass);
        return this.addInsertMappedStatement(mapperClass, modelClass, getMethod(sqlMethod), sqlSource, keyGenerator, keyProperty, keyColumn);
    }

    /**
     * Common batch-insert SQL, used by databases that accept a multi-row {@code VALUES} list.
     * The generated SQL looks roughly like:
     * <pre>
     * INSERT INTO student_table
     * (name, age, salary)
     * VALUES
     * &lt;foreach collection="(Constants.COLLECTION)" item="item" index="index" separator=","&gt;
     * (#{item.name}, #{item.age}, #{item.salary})
     * &lt;/foreach&gt;
     * </pre>
     *
     * @param mapperClass mapper class
     * @param modelClass  entity class
     * @param tableInfo   table metadata
     * @return SQL fragments keyed by their location in the statement
     */
    public Map<SqlLocation, String> genCommonSql(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        Map<SqlLocation, String> resultSqlMap = new HashMap<>();
        resultSqlMap.put(SqlLocation.VALUES, " VALUES ");
        resultSqlMap.put(SqlLocation.COLUMN, SqlScriptUtils.convertTrim(tableInfo.getAllInsertSqlColumn(null),
            LEFT_BRACKET, RIGHT_BRACKET, null, COMMA));
        // One "(#{item.a},#{item.b},...)" tuple per collection element, joined by commas
        String rowValues = SqlScriptUtils.convertTrim(tableInfo.getAllInsertSqlProperty(Constants.ITEM + DOT),
            LEFT_BRACKET, RIGHT_BRACKET, null, COMMA);
        resultSqlMap.put(SqlLocation.VALUE, SqlScriptUtils.convertForeach(rowValues,
            Constants.COLLECTION, Constants.INDEX, Constants.ITEM, COMMA));
        return resultSqlMap;
    }

    /**
     * Oracle batch-insert SQL. Oracle does not accept the multi-row {@code VALUES} form, so rows
     * are produced by {@code SELECT ... FROM dual} joined with {@code UNION ALL}:
     * <pre>
     * INSERT INTO user (user_id, name, age, school_id)
     * select user_id, name, age, school_id from (
     *   &lt;foreach collection="(Constants.COLLECTION)" item="item" index="index" separator="union all"&gt;
     *     select
     *     &lt;trim suffixOverrides=","&gt;
     *       #{item.userId} as user_id, #{item.name} as name, #{item.age} as age, #{item.schoolId} as school_id,
     *     &lt;/trim&gt;
     *     from dual
     *   &lt;/foreach&gt;
     * )
     * </pre>
     * A property whose value is null is emitted as a literal {@code null} (via {@code <choose>})
     * instead of a bound parameter.
     *
     * @param mapperClass mapper class
     * @param modelClass  entity class
     * @param tableInfo   table metadata
     * @return SQL fragments keyed by their location in the statement
     * @throws IllegalStateException if the table metadata yields more insert properties than columns
     */
    public Map<SqlLocation, String> genOracleSql(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        Map<SqlLocation, String> resultSqlMap = new HashMap<>();
        // e.g. "col1,col2,col3,"
        String insertColumns = tableInfo.getAllInsertSqlColumn(null);
        // e.g. "#{item.col1},#{item.col2},#{item.col3},"
        String insertProperties = tableInfo.getAllInsertSqlProperty(Constants.ITEM + DOT);

        // (col1,col2,col3)
        resultSqlMap.put(SqlLocation.COLUMN, SqlScriptUtils.convertTrim(insertColumns,
            LEFT_BRACKET, RIGHT_BRACKET, null, COMMA));
        // Oracle's INSERT ... SELECT form must not contain the VALUES keyword, otherwise it fails
        resultSqlMap.put(SqlLocation.VALUES, "  ");

        // select col1,col2,col3 from (
        String outerSelect = SELECT + SqlScriptUtils.convertTrim(insertColumns, null, null, null, COMMA)
            + FROM + LEFT_BRACKET;

        // Build "#{item.col1} as col1,#{item.col2} as col2,..." with each value null-guarded
        String[] columns = insertColumns.split(COMMA);
        String[] properties = splitInsertProperties(insertProperties);
        StringBuilder selectValueField = new StringBuilder();
        for (int i = 0; i < properties.length; i++) {
            String property = properties[i];
            if (StringUtils.isNotBlank(property)) {
                if (i >= columns.length) {
                    throw new IllegalStateException("batch insert: more insert properties than columns for table "
                        + tableInfo.getTableName());
                }
                selectValueField.append(SqlScriptUtils.convertChoose(parsePropertyOracleNotNull(property), property, "null"))
                    .append(AS).append(columns[i].trim()).append(COMMA);
            }
        }

        // select ... from dual union all select ... from dual ...
        String selectFromDual = SqlScriptUtils.convertForeach(SELECT + SqlScriptUtils.convertTrim(selectValueField.toString(),
            null, null, null, COMMA) + FROM + DUAL, Constants.COLLECTION, Constants.INDEX, Constants.ITEM, UNION_ALL);
        resultSqlMap.put(SqlLocation.VALUE, outerSelect + selectFromDual + RIGHT_BRACKET);
        return resultSqlMap;
    }

    /**
     * Splits the comma-joined insert placeholders into individual {@code #{...}} entries.
     * Commas inside a placeholder (e.g. {@code #{item.name,jdbcType=VARCHAR}}) are not split
     * points, so each entry stays aligned with its column.
     *
     * @param insertProperties comma-joined placeholders
     * @return the placeholders; trailing empty entries are dropped as by {@link String#split(String)}
     */
    private static String[] splitInsertProperties(String insertProperties) {
        // split on a comma only when it is not followed by a '}' before any '{' (i.e. not inside #{...})
        return insertProperties.split(",(?![^{]*\\})");
    }

    /**
     * Converts a placeholder into an OGNL not-null test: {@code #{item.col}} becomes
     * {@code item.col!=null}. Parameter options after the property path
     * (e.g. {@code ,jdbcType=VARCHAR}) are dropped because they are not part of the expression.
     *
     * @param property placeholder text
     * @return the not-null test expression
     */
    private String parsePropertyOracleNotNull(String property) {
        String expression = property.replace("#{", "").replace("}", "");
        int optionsStart = expression.indexOf(',');
        if (optionsStart >= 0) {
            expression = expression.substring(0, optionsStart);
        }
        return expression.trim() + "!=null";
    }

    /**
     * Registers the SQL generator for each supported database type.
     *
     * @return this instance
     */
    public InsertBatch init() {
        // These databases accept the common multi-row "INSERT ... VALUES (...),(...)" form.
        // MySQL: tested
        DB_FUNCTIONS.put(DbType.MYSQL, this::genCommonSql);
        // PostgreSQL: tested
        DB_FUNCTIONS.put(DbType.POSTGRE_SQL, this::genCommonSql);
        // SQL Server 2019 Express: tested (the test configuration sets GlobalConfig's dbType to sqlserver)
        DB_FUNCTIONS.put(DbType.SQL_SERVER, this::genCommonSql);
        // Oracle (11g/12c) rejects multi-row VALUES, so both use the INSERT ... SELECT ... FROM dual form.
        // Tested on 11g; 12c is expected to behave the same.
        DB_FUNCTIONS.put(DbType.ORACLE, this::genOracleSql);
        DB_FUNCTIONS.put(DbType.ORACLE_12C, this::genOracleSql);
        return this;
    }

    /**
     * Generates the database-specific batch-insert SQL fragments.
     *
     * @author wanglei
     * @since 2022-03-22
     */
    @FunctionalInterface
    public interface GenSql {
        /**
         * Generates the SQL fragments.
         *
         * @param mapperClass mapper class
         * @param modelClass  entity class
         * @param tableInfo   table metadata
         * @return SQL fragments keyed by their location in the statement
         */
        Map<SqlLocation, String> genSql(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo);
    }

    /**
     * Location of a SQL fragment within the batch-insert statement.
     *
     * @author wanglei
     * @since 2022-03-22
     */
    public enum SqlLocation {
        /** Column list, e.g. "(col1,col2)". */
        COLUMN,
        /** Keyword between the column list and the row data, e.g. " VALUES ". */
        VALUES,
        /** Row data script. */
        VALUE
    }

}
